{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#1 导入数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepalLength</th>\n",
       "      <th>sepalWidth</th>\n",
       "      <th>petalLength</th>\n",
       "      <th>petalWidth</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepalLength  sepalWidth  petalLength  petalWidth        category\n",
       "0            5.1         3.5          1.4         0.2     Iris-setosa\n",
       "1            4.9         3.0          1.4         0.2     Iris-setosa\n",
       "2            4.7         3.2          1.3         0.2     Iris-setosa\n",
       "3            4.6         3.1          1.5         0.2     Iris-setosa\n",
       "4            5.0         3.6          1.4         0.2     Iris-setosa\n",
       "..           ...         ...          ...         ...             ...\n",
       "145          6.7         3.0          5.2         2.3  Iris-virginica\n",
       "146          6.3         2.5          5.0         1.9  Iris-virginica\n",
       "147          6.5         3.0          5.2         2.0  Iris-virginica\n",
       "148          6.2         3.4          5.4         2.3  Iris-virginica\n",
       "149          5.9         3.0          5.1         1.8  Iris-virginica\n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "names=['sepalLength','sepalWidth','petalLength','petalWidth','category']\n",
    "data=pd.read_csv('./iris.data',sep=',',names=names)\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#2.切分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "Xtrain,Xtest,Ytrain,Ytest =train_test_split(data.iloc[:,:-1]\n",
    "                                ,data.iloc[:,-1],test_size=0.3,random_state=360)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#3.使用标准化包，学习训练集的数据分布，从而对训练集和测试集来做标准化（20分）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "std=StandardScaler().fit(Xtrain)\n",
    "Xtrain_=std.fit_transform(Xtrain)\n",
    "Xtest_ =std.fit_transform(Xtest)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#4.在确定l2范式的情况下，使用网格搜索判断solver, C的最优组合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': 0.4722222222222222, 'solver': 'sag'}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression as LR #逻辑回归\n",
    "#在l2范式下，判断C和solver的最优值\n",
    "p = {\n",
    "    'C':list(np.linspace(0.05,1,19)),\n",
    "    'solver':['liblinear','sag','newton-cg','lbfgs']\n",
    "}\n",
    "\n",
    "model = LR(penalty='l2',max_iter=10000)\n",
    "\n",
    "GS = GridSearchCV(model,p,cv=5)\n",
    "GS.fit(Xtrain_,Ytrain)\n",
    "GS.best_params_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#5.将最优的结果重新用来实例化模型，查看训练集和测试集下的分数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9619047619047619, 0.9555555555555556)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#将最优参数重新用于实例化模型，查看训练集和测试集下的分数\n",
    "model = LR(penalty='l2',\n",
    "           max_iter=10000,\n",
    "           C=GS.best_params_['C'],\n",
    "           solver=GS.best_params_['solver'])  #sag 三种通过导数计算的方式是不能l1正则化的\n",
    "\n",
    "model.fit(Xtrain_,Ytrain)\n",
    "model.score(Xtrain_,Ytrain),model.score(Xtest_,Ytest)\n",
    "           "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#6.计算精准率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "    Iris-setosa       1.00      1.00      1.00        17\n",
      "Iris-versicolor       0.93      0.93      0.93        14\n",
      " Iris-virginica       0.93      0.93      0.93        14\n",
      "\n",
      "       accuracy                           0.96        45\n",
      "      macro avg       0.95      0.95      0.95        45\n",
      "   weighted avg       0.96      0.96      0.96        45\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9555555555555556"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import precision_score\n",
    "Ypredict=model.predict(Xtest_)\n",
    "\n",
    "report=classification_report(Ytest,Ypredict)\n",
    "print(report)\n",
    "precision_score(Ytest,Ypredict,average='micro')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "7ef57b1a12f99e7c74e468f4512270929d1986b547266cbc198c69385640fc98"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
